# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from math import ceil

import numpy as np
from scipy.fft import irfft, rfft, rfftfreq

from ..utils import logger, verbose


@verbose
def stft(x, wsize, tstep=None, verbose=None):
    """STFT Short-Term Fourier Transform using a sine window.

    The transformation is designed to be a tight frame that can be
    perfectly inverted. It only returns the positive frequencies.

    Parameters
    ----------
    x : array, shape (n_signals, n_times) | shape (n_times,)
        Containing multi-channels signal. A 1D array is treated as a
        single signal.
    wsize : int
        Length of the STFT window in samples (must be a multiple of 4).
    tstep : int
        Step between successive windows in samples (must be a positive
        multiple of 2, a divider of wsize and at most wsize/2)
        (default: wsize/2).
    %(verbose)s

    Returns
    -------
    X : array, shape (n_signals, wsize // 2 + 1, n_step)
        STFT coefficients for positive frequencies with
        ``n_step = ceil(T / tstep)``.

    Raises
    ------
    ValueError
        If ``x`` is not real valued, or if ``wsize`` or ``tstep`` violate
        the constraints above.

    See Also
    --------
    istft
    stftfreq
    """
    # Accept any array-like (e.g. nested lists), not only ndarrays.
    x = np.asarray(x)
    if not np.isrealobj(x):
        raise ValueError("x is not a real valued array")

    if x.ndim == 1:
        x = x[None, :]

    n_signals, T = x.shape
    wsize = int(wsize)

    # Errors and warnings
    if wsize % 4:
        raise ValueError("The window length must be a multiple of 4.")

    if tstep is None:
        tstep = wsize // 2

    tstep = int(tstep)

    # Guard before the modulo below, which would otherwise raise
    # ZeroDivisionError for a zero step.
    if tstep <= 0:
        raise ValueError(f"The step size must be positive, got {tstep}.")

    if (wsize % tstep) or (tstep % 2):
        raise ValueError(
            "The step size must be a multiple of 2 and a divider of the window length."
        )

    if tstep > wsize / 2:
        raise ValueError("The step size must be smaller than half the window length.")

    n_step = int(ceil(T / float(tstep)))
    n_freq = wsize // 2 + 1
    logger.info(f"Number of frequencies: {n_freq}")
    logger.info(f"Number of time steps: {n_step}")

    X = np.zeros((n_signals, n_freq, n_step), dtype=np.complex128)

    if n_signals == 0:
        return X

    # Defining sine window
    win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
    win2 = win**2

    # Sum of squared windows over all overlapping frames; dividing each
    # frame by it makes the analysis/synthesis pair a tight frame.
    swin = np.zeros((n_step - 1) * tstep + wsize)
    for t in range(n_step):
        swin[t * tstep : t * tstep + wsize] += win2
    swin = np.sqrt(wsize * swin)

    # Zero-padding and Pre-processing for edges: the signal is centered so
    # that istft can trim the same (wsize - tstep) // 2 offset.
    xp = np.zeros((n_signals, wsize + (n_step - 1) * tstep), dtype=x.dtype)
    xp[:, (wsize - tstep) // 2 : (wsize - tstep) // 2 + T] = x
    x = xp

    for t in range(n_step):
        # Framing
        wwin = win / swin[t * tstep : t * tstep + wsize]
        frame = x[:, t * tstep : t * tstep + wsize] * wwin[None, :]
        # FFT (positive frequencies only)
        X[:, :, t] = rfft(frame)

    return X


def istft(X, tstep=None, Tx=None):
    """ISTFT Inverse Short-Term Fourier Transform using a sine window.

    Inverts the sine-window STFT by overlap-adding windowed inverse real
    FFTs of every time step.

    Parameters
    ----------
    X : array, shape (..., wsize / 2 + 1, n_step)
        The STFT coefficients for positive frequencies.
    tstep : int
        Step between successive windows in samples (must be a positive
        multiple of 2, a divider of wsize and at most wsize/2)
        (default: wsize/2).
    Tx : int
        Length of returned signal. If None Tx = n_step * tstep.

    Returns
    -------
    x : array, shape (..., Tx)
        Array containing the inverse STFT signal.

    Raises
    ------
    ValueError
        If ``X`` has fewer than 2 dimensions or an even number of frequency
        rows, or if ``tstep`` violates the constraints above.

    See Also
    --------
    stft
    """
    # Errors and warnings
    X = np.asarray(X)
    if X.ndim < 2:
        raise ValueError(f"X must have ndim >= 2, got {X.ndim}")
    n_win, n_step = X.shape[-2:]
    signal_shape = X.shape[:-2]
    if n_win % 2 == 0:
        raise ValueError("The number of rows of the STFT matrix must be odd.")

    wsize = 2 * (n_win - 1)
    if tstep is None:
        # Integer division: a float step would later break the array shape
        # passed to np.zeros and every slice below.
        tstep = wsize // 2
    tstep = int(tstep)

    # Guard before the modulo below, which would otherwise raise
    # ZeroDivisionError for a zero step.
    if tstep <= 0:
        raise ValueError(f"The step size must be positive, got {tstep}.")

    if wsize % tstep:
        raise ValueError(
            "The step size must be a divider of two times the "
            "number of rows of the STFT matrix minus two."
        )

    # The parity constraint applies to the step (as in stft); wsize is
    # always even here by construction, so testing it would be a no-op.
    if tstep % 2:
        raise ValueError("The step size must be a multiple of 2.")

    if tstep > wsize / 2:
        raise ValueError(
            "The step size must be smaller than the number of "
            "rows of the STFT matrix minus one."
        )

    T = n_step * tstep
    if Tx is None:
        Tx = T

    x = np.zeros(signal_shape + (T + wsize - tstep,), dtype=np.float64)

    if np.prod(signal_shape) == 0:
        return x[..., :Tx]

    # Defining sine window
    win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
    win2 = win**2

    # Pre-processing for edges: overlap-summed squared window, used to
    # normalize each synthesis frame.
    swin = np.zeros(T + wsize - tstep, dtype=np.float64)
    for t in range(n_step):
        swin[t * tstep : t * tstep + wsize] += win2
    swin = np.sqrt(swin / wsize)

    for t in range(n_step):
        # IFFT
        frame = irfft(X[..., t], wsize)
        # Overlap-add
        frame *= win / swin[t * tstep : t * tstep + wsize]
        x[..., t * tstep : t * tstep + wsize] += frame

    # Truncation: drop the (wsize - tstep) // 2 leading padding samples.
    # The slice may keep one sample past T; [:Tx] trims it when Tx <= T.
    x = x[..., (wsize - tstep) // 2 : (wsize - tstep) // 2 + T + 1]
    x = x[..., :Tx].copy()
    return x


def stftfreq(wsize, sfreq=None):  # noqa: D401
    """Compute frequencies of stft transformation.

    Parameters
    ----------
    wsize : int
        Size of stft window.
    sfreq : float
        Sampling frequency. If None the frequencies are given in cycles per
        sample, i.e. between 0 and 0.5; otherwise they are given in Hz
        (between 0 and sfreq / 2).

    Returns
    -------
    freqs : array, shape (wsize // 2 + 1,)
        The positive frequencies returned by stft.

    See Also
    --------
    stft
    istft
    """
    # rfftfreq with the default sample spacing d=1 yields cycles/sample.
    freqs = rfftfreq(wsize)
    if sfreq is not None:
        freqs *= float(sfreq)
    return freqs


def stft_norm2(X):
    """Compute L2 norm of STFT transform.

    It takes into account that stft only return positive frequencies.
    As we use tight frame this quantity is conserved by the stft.

    Parameters
    ----------
    X : complex array, shape (..., n_freq, n_step)
        The STFT transforms (typically 3D: n_signals x n_freq x n_step).

    Returns
    -------
    norms2 : array, shape (...)
        The squared L2 norm of every signal in X, i.e. reduced over the
        last two (frequency and time) axes.
    """
    X = np.asarray(X)
    X2 = (X * X.conj()).real
    # Interior frequency rows stand for a +/- frequency pair, so every
    # coefficient is counted twice and then the first (DC) and last
    # (Nyquist, for the even window length stft enforces) rows are removed
    # once, since they have no mirrored counterpart.
    norms2 = (
        2.0 * X2.sum(axis=(-2, -1))
        - X2[..., 0, :].sum(axis=-1)
        - X2[..., -1, :].sum(axis=-1)
    )
    return norms2


def stft_norm1(X):
    """Compute L1 norm of STFT transform.

    It takes into account that stft only return positive frequencies.

    Parameters
    ----------
    X : complex array, shape (..., n_freq, n_step)
        The STFT transforms (typically 3D: n_signals x n_freq x n_step).

    Returns
    -------
    norms : array, shape (...)
        The L1 norm of every signal in X, i.e. reduced over the last two
        (frequency and time) axes.
    """
    X_abs = np.abs(np.asarray(X))
    # Count every coefficient twice for the implied negative frequencies,
    # then remove the first and last frequency rows once, since they have
    # no mirrored counterpart.
    norms = (
        2.0 * X_abs.sum(axis=(-2, -1))
        - X_abs[..., 0, :].sum(axis=-1)
        - X_abs[..., -1, :].sum(axis=-1)
    )
    return norms
